Machine Learning (Advanced) Nanodegree Capstone Project

Dogs vs. Cats

Zhang Hongliang, Udacity

Project Background

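Classifying cats and dogs is a classic problem in computer vision. A two-year-old child can easily tell a cat from a dog in a picture, yet getting a computer to do the same was long a mountain that machine learning could not climb. About 30 years ago, Geoffrey Hinton, often called the father of deep learning, helped bring multi-layer neural networks into machine learning. That work laid the foundation for the rapid progress of deep learning over the past decade, which finally solved a problem that had troubled machine learning practitioners for years. Many deep learning models for image recognition have appeared since, and computer vision has become one of the most active areas of AI research.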

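The data-science competition platform Kaggle hosts a practice competition, Dogs vs. Cats, for machine learning enthusiasts. It provides 25,000 labeled cat and dog images as the training set and 12,500 cat and dog images as the test set.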

To improve model performance, this project also adds the 7,393 cat and dog images of The Oxford-IIIT Pet Dataset to the training set.

Most of the model training was done on an AWS p3.2xlarge instance. Training and tuning took 70 hours in total, of which 57 hours were on p3.2xlarge and 3 hours on p2.xlarge.

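The overall approach was to build on Keras pre-trained models. I started with VGG16, but the validation logloss did not change at all. Adjusting the hidden activation, the learning rate and the output activation did not help; only after adding a GlobalAveragePooling2D layer did the logloss start to fall with each epoch, and even then the model performed poorly.

I then switched to ResNet50 with PReLU activations, trained on all the data: validation logloss 0.0493, accuracy 0.9828, Kaggle score 0.12, which did not meet the project requirement. Changing the activation to ELU, the fully connected layer to 256 units, the optimizer to Nadam and the learning rate to 0.0005 lowered the validation logloss to 0.0368 (accuracy 0.9867, Kaggle score 0.09334), still short of the requirement, and further tuning brought no improvement. Next I enlarged the training set with the 7,393 images of The Oxford-IIIT Pet Dataset: validation logloss 0.03286, accuracy 0.9882, Kaggle score 0.07519, still not enough. Data augmentation with ImageDataGenerator did not help either.

Finally I tried model ensembling. Combining ResNet50, InceptionV3 and Xception brought the validation logloss down to 0.0069 (accuracy 0.9985) and the Kaggle score to 0.0417. While learning about ensembling I also came across clipping: restricting predictions to a reasonable range can improve the Kaggle score significantly. Applying clipping to the earlier single model brought its Kaggle score to 0.05719, which also meets the requirement. For this reason, both models are kept in this project.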
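Why clipping helps: the competition is scored with log loss, which penalizes a confident wrong prediction without bound, while clipping costs only a little on confident correct ones. The following is a minimal sketch with made-up predictions (99 confident correct answers and one confident mistake). The bounds 0.005 and 0.995 are an assumption for illustration, because the bounds actually used in this project are not stated here.

```python
import numpy as np

def log_loss(y_true, y_pred, eps=1e-15):
    """Binary log loss. eps only guards against log(0); it is not the clipping studied here."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

# Hypothetical predictions: 99 confident correct answers, one confident mistake
y_true = np.array([1.0] * 100)
y_pred = np.array([0.999] * 99 + [0.001])

raw = log_loss(y_true, y_pred)
# Clip into [0.005, 0.995] before scoring (assumed bounds)
clipped = log_loss(y_true, np.clip(y_pred, 0.005, 0.995))
print('raw: {:.4f}, clipped: {:.4f}'.format(raw, clipped))
```

The single mistake dominates the unclipped score. Clipping caps its penalty at -ln(0.005) and adds only a small cost to each correct prediction.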

Load the required modules

In [1]:
import os, cv2, random, pickle
from tqdm import tqdm
import numpy as np
import pandas as pd
import csv
import shutil
import h5py

from urllib.request import urlretrieve
from os.path import isfile, isdir
import utils  # project-local helper module (provides trange_file_name, among others)
from utils import *
import tarfile

import matplotlib.pyplot as plt
from matplotlib import ticker
%matplotlib inline

from keras.layers import Input, Dropout, Flatten, Dense, Activation, GlobalAveragePooling2D
from keras.optimizers import RMSprop,Nadam,SGD,Adam
from keras.callbacks import ModelCheckpoint, Callback, EarlyStopping, CSVLogger, LearningRateScheduler, ReduceLROnPlateau
from keras.utils import np_utils
from keras.models import load_model

from keras.preprocessing import image
from keras.preprocessing.image import ImageDataGenerator
from keras.layers.normalization import BatchNormalization
from keras.layers.advanced_activations import PReLU, ELU

from keras.applications.vgg16 import VGG16
from keras.applications.xception import Xception
from keras.applications.resnet50 import ResNet50, preprocess_input
from keras.applications.inception_v3 import InceptionV3
from keras.models import Model
from keras import layers

from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot
Using TensorFlow backend.

Data Acquisition

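The project runs on Linux, so the competition data is downloaded with the Kaggle API command-line tool: kaggle competitions download -c dogs-vs-cats-redux-kernels-edition. Downloading requires a Kaggle account, so the API credentials file kaggle.json must be configured first.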
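The Kaggle CLI reads its API token from ~/.kaggle/kaggle.json and warns if the file is readable by other users. The following is a minimal Python sketch of that setup step. It assumes the token has already been downloaded from the Kaggle account page as kaggle.json. The helper name install_kaggle_credentials is hypothetical and is not part of the Kaggle tooling.

```python
import os
import shutil
import stat

def install_kaggle_credentials(src_json, kaggle_dir=os.path.expanduser('~/.kaggle')):
    """Hypothetical helper: copy the API token to where the Kaggle CLI looks for it
    and restrict it to owner read/write (equivalent to chmod 600)."""
    os.makedirs(kaggle_dir, exist_ok=True)
    dst = os.path.join(kaggle_dir, 'kaggle.json')
    shutil.copyfile(src_json, dst)
    os.chmod(dst, stat.S_IRUSR | stat.S_IWUSR)
    return dst
```

After calling install_kaggle_credentials('kaggle.json'), run the download command above from a shell.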

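The Kaggle data downloads as a compressed archive. Extracting it produces a training folder and a test folder. The training folder holds 25,000 color images, 12,500 cats and 12,500 dogs, each named with its label and a number (e.g. dog.11731.jpg). The test folder holds 12,500 shuffled cat and dog images named only by number.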

The Oxford-IIIT Pet Dataset is downloaded directly with urlretrieve and extracted with tarfile.open.

In [3]:
# Download the Oxford-IIIT Pet Dataset (supplementary training data)

image_supply_path = './input/images'
tar_gz_path = './images.tar.gz'

class DLProgress(tqdm):
    """tqdm progress bar usable as a urlretrieve reporthook."""
    last_block = 0

    def hook(self, block_num=1, block_size=1, total_size=None):
        self.total = total_size
        self.update((block_num - self.last_block) * block_size)
        self.last_block = block_num

# Download the archive only if it is not already present
if not isfile(tar_gz_path):
    with DLProgress(unit='B', unit_scale=True, miniters=1, desc='images supply dataset') as pbar:
        urlretrieve('http://www.robots.ox.ac.uk/%7Evgg/data/pets/data/images.tar.gz', tar_gz_path, pbar.hook)

# Extract into ./input/images unless it has already been extracted
if not isdir(image_supply_path):
    with tarfile.open(tar_gz_path) as tar:  # the with-block closes the archive
        tar.extractall(path='./input/')

test_folder_path(image_supply_path)
All files found!

Data Visualization

In [11]:
TRAIN_DIR = './input/train/'
train_images_path = [TRAIN_DIR+i for i in os.listdir(TRAIN_DIR)]

# Show 9 random training images

random.seed(2018)

def random_show(location):
    plt.subplot(location)
    sample = random.choice(train_images_path)
    img = cv2.imread(sample)
    b,g,r = cv2.split(img) # OpenCV loads BGR; reorder the channels to RGB for matplotlib
    rgb_img = cv2.merge([r,g,b])
    plt.title(sample)
    plt.imshow(rgb_img)
    
plt.figure(figsize=(12,12))
plt.subplots_adjust(wspace=0.2, hspace=0.2)
for location in range(331, 340):
    random_show(location)
plt.show()
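The split/merge pair in random_show only reverses the channel order. OpenCV also provides this directly as cv2.cvtColor(img, cv2.COLOR_BGR2RGB). The following NumPy-only sketch, using an arbitrary toy pixel array, shows that reversing the last axis gives the same result.

```python
import numpy as np

# A tiny 1x2 "image" in OpenCV's BGR channel order (values are arbitrary)
bgr = np.array([[[255, 0, 0], [0, 128, 64]]], dtype=np.uint8)

# Same as: b, g, r = cv2.split(img); cv2.merge([r, g, b])
b, g, r = bgr[..., 0], bgr[..., 1], bgr[..., 2]
rgb_split_merge = np.stack([r, g, b], axis=-1)

# Reversing the channel axis does it in one step (a view, no copy)
rgb_reversed = bgr[..., ::-1]

print(np.array_equal(rgb_split_merge, rgb_reversed))
```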
In [12]:
# Show 9 random images from the supplementary set

random.seed(2018)

TRAIN_SUP_DIR = './input/images/'
train_images_path_sup = [TRAIN_SUP_DIR + file for file in os.listdir(TRAIN_SUP_DIR)]

def random_show(location):
    plt.subplot(location)
    sample = random.choice(train_images_path_sup)
    img = cv2.imread(sample)
    b,g,r = cv2.split(img) # OpenCV loads BGR; reorder the channels to RGB for matplotlib
    rgb_img = cv2.merge([r,g,b])
    plt.title(sample)
    plt.imshow(rgb_img)
    
plt.figure(figsize=(12,12))
plt.subplots_adjust(wspace=0.2, hspace=0.2)
for location in range(331, 340):
    random_show(location)
plt.show()

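The samples above show that the images vary in width and height, so they must be resized to a common size before training.

In addition, the images are not all close-ups of cats and dogs. Backgrounds are often cluttered and lighting varies, so overfitting must be guarded against during training.

The Oxford-IIIT Pet images are named after their breed rather than with dog or cat, so the file names need to be converted. The Oxford-IIIT Pet website documents which breeds are cats and which are dogs, and this makes the conversion possible.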
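The project's own renaming helper (trange_file_name in utils) is not shown. As a hypothetical alternative, the sketch below assumes the dataset's naming convention that cat breed names start with an upper-case letter (e.g. Abyssinian_100.jpg) and dog breed names with a lower-case letter (e.g. american_bulldog_12.jpg). Check this convention against the dataset documentation before relying on it. Both function names are illustrative.

```python
import os

def species_from_filename(filename):
    """Return 'cat' or 'dog' for an Oxford-IIIT Pet file name.
    Assumes cat breeds are capitalized and dog breeds are lower-case."""
    first = os.path.basename(filename)[0]
    return 'cat' if first.isupper() else 'dog'

def new_name(filename, index):
    """Hypothetical new name in the cat_<n>.jpg / dog_<n>.jpg style."""
    return '{}_{}.jpg'.format(species_from_filename(filename), index)

print(species_from_filename('Abyssinian_100.jpg'))
print(new_name('./input/images/american_bulldog_12.jpg', 5))
```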

Next, we visualize the widths and heights across the whole dataset.

In [10]:
# Histograms of image height and width: Kaggle training set

height = []
width = []

for file in tqdm(train_images_path):
    image = cv2.imread(file)
    height.append(image.shape[0])
    width.append(image.shape[1])
    
plt.figure(figsize=(12,6))
plt.subplots_adjust(wspace=0.2, hspace=0.2)
    
plt.subplot(121)
plt.hist(height)
plt.title("height distribution")

plt.subplot(122)
plt.hist(width)
plt.title("width distribution")

plt.show()

print('median of height: {}'.format(np.median(height)))
print('median of width: {}'.format(np.median(width)))
100%|███████████████████████████████████████████████████████████████████████████| 25000/25000 [01:21<00:00, 305.30it/s]
median of height: 374.0
median of width: 447.0

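The median (width, height) of the Kaggle training images is (447, 374). Resizing every image to 350×350 is therefore a reasonable choice: it sits just below the median on both axes, so at least half of the images are shrunk slightly on each axis rather than enlarged.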

In [13]:
# Histograms of image height and width: supplementary dataset

height = []
width = []

for file in tqdm(train_images_path_sup):
    image = cv2.imread(file)
    if image is not None:  # skip files that OpenCV cannot read (imread returns None)
        height.append(image.shape[0])
        width.append(image.shape[1])
    
plt.figure(figsize=(12,6))
plt.subplots_adjust(wspace=0.2, hspace=0.2)
    
plt.subplot(121)
plt.hist(height)
plt.title("height distribution")

plt.subplot(122)
plt.hist(width)
plt.title("width distribution")

plt.show()

print('median of height: {}'.format(np.median(height)))
print('median of width: {}'.format(np.median(width)))
100%|██████████████████████████████████████████████████████████████████████████████| 7393/7393 [03:48<00:00, 32.42it/s]
median of height: 375.0
median of width: 500.0

The median (width, height) of the supplementary images is (500, 375), so 350×350 is a reasonable target size for this set as well.
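The following worked check shows what the 350×350 target means for a median-sized image. Both sides shrink, but cv2.resize to a fixed (350, 350) size does not preserve the aspect ratio, so wider-than-tall images are squeezed horizontally. The function name resize_factors is illustrative. The medians are the values printed above.

```python
def resize_factors(width, height, target=350):
    """Per-axis scale factors for a square target, plus the aspect change.

    Returns (sx, sy, aspect_before, squeeze), where squeeze = sx / sy is how much
    the width is scaled relative to the height (1.0 means no distortion)."""
    sx = target / width
    sy = target / height
    return sx, sy, width / height, sx / sy

# Median (width, height) values printed above
medians = {'kaggle': (447, 374), 'supply': (500, 375)}

for name, (w, h) in medians.items():
    sx, sy, aspect, squeeze = resize_factors(w, h)
    print('{}: sx={:.3f}, sy={:.3f}, aspect {:.2f} -> 1.00, relative width scale {:.3f}'.format(
        name, sx, sy, aspect, squeeze))
```

For the supplementary median image, the width is scaled to 0.75 of the height's scale. This is the trade-off of a fixed square input versus padding or cropping.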

In [14]:
# Scatter plots of image width vs. height

height_kg = []
width_kg = []
height_su = []
width_su = []

for file in train_images_path:
    image = cv2.imread(file)
    height_kg.append(image.shape[0])
    width_kg.append(image.shape[1])
    
for file in train_images_path_sup:
    image = cv2.imread(file)
    if image is not None:  # skip files that OpenCV cannot read (imread returns None)
        height_su.append(image.shape[0])
        width_su.append(image.shape[1])

plt.figure(figsize=(12,6))
plt.subplots_adjust(wspace=0.2, hspace=0.2)
    
plt.subplot(121)
plt.scatter(width_kg, height_kg, s=50)
plt.title("height&width scatter of kaggle dataset")

plt.subplot(122)
plt.scatter(width_su, height_su, s=50)
plt.title("height&width scatter of supply dataset")

plt.show()

print('median of kaggle height: {}, median of kaggle width: {}'.format(np.median(height_kg),np.median(width_kg)))
print('median of supply height: {}, median of supply width: {}'.format(np.median(height_su),np.median(width_su)))
median of kaggle height: 374.0, median of kaggle width: 447.0
median of supply height: 375.0, median of supply width: 500.0

Single Model: ResNet50

At the start of the project I chose to complete it with a single pre-trained model. After several experiments, I settled on ResNet50.

Data Preprocessing

Before the training set can be used, it must be preprocessed: images are resized, shuffled, renamed, and anomalous images are removed.

Kaggle Training Set Preprocessing

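In the competition's discussion forum, another participant shared a CSV file listing anomalous images. This file is used to remove those images from the training set.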

In [35]:
# Build the list of anomalous images from relabel.csv

random.seed(2018)

ab_img_list = []

with open('relabel.csv') as f:
    for row in csv.reader(f):
        name = row[1]
        # Only rows whose second column names a cat or dog image are relevant
        if 'dog' in name or 'cat' in name:
            if '.jpg' not in name:
                name = name + '.jpg'  # add the extension when the entry lacks it
            ab_img_list.append('./input/train/' + name)
    
# Show 9 random anomalous images
def random_show(location):
    plt.subplot(location)
    sample = random.choice(ab_img_list)
    img = cv2.imread(sample)
    b,g,r = cv2.split(img) # OpenCV loads BGR; reorder the channels to RGB for matplotlib
    rgb_img = cv2.merge([r,g,b])
    plt.title(sample)
    plt.imshow(rgb_img)
    
plt.figure(figsize=(12,12))
plt.subplots_adjust(wspace=0.2, hspace=0.2)
for location in range(331, 340):
    random_show(location)
plt.show()
In [7]:
# Remove the anomalous images from the training set

ab_img_list = []

with open('relabel.csv') as f:
    for row in csv.reader(f):
        name = row[1]
        if 'dog' in name or 'cat' in name:
            if '.jpg' not in name:
                name = name + '.jpg'
            ab_img_list.append('./input/train/' + name)

# Filter into a new list: calling remove() while iterating over the same list skips elements
ab_img_set = set(ab_img_list)
n_before = len(train_images_path)
train_images_path = [file for file in train_images_path if file not in ab_img_set]
print('deleted {} files.'.format(n_before - len(train_images_path)))
deleted 49 files.
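If you delete entries by calling list.remove() inside a for loop over the same list, the iteration skips the element right after each removal. Filtering into a new list is therefore the safer pattern for these deletions. The demonstration below uses made-up file names. Both function names are illustrative.

```python
def remove_while_iterating(items, bad):
    """Unsafe pattern: mutates the list that is being iterated over."""
    items = list(items)
    for x in items:
        if x in bad:
            items.remove(x)  # later elements shift left; the next one is skipped
    return items

def filter_comprehension(items, bad):
    """Safe alternative: build a new list instead of mutating in place."""
    bad = set(bad)
    return [x for x in items if x not in bad]

files = ['cat.1.jpg', 'cat.2.jpg', 'dog.3.jpg', 'dog.4.jpg']
bad = ['cat.1.jpg', 'cat.2.jpg']  # two adjacent entries to delete

print(remove_while_iterating(files, bad))  # cat.2.jpg survives
print(filter_comprehension(files, bad))
```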
In [37]:
# Show the two mislabeled images

wrong_label_list = ['./input/train/dog.11731.jpg', './input/train/dog.4334.jpg']


plt.figure(figsize=(8,8))
plt.subplots_adjust(wspace=0.2, hspace=0.2)

plt.subplot(121)
img = cv2.imread(wrong_label_list[0])
b,g,r = cv2.split(img) # OpenCV loads BGR; reorder the channels to RGB for matplotlib
rgb_img = cv2.merge([r,g,b])
plt.title(wrong_label_list[0])
plt.imshow(rgb_img)
    
plt.subplot(122)
img = cv2.imread(wrong_label_list[1])
b,g,r = cv2.split(img) # OpenCV loads BGR; reorder the channels to RGB for matplotlib
rgb_img = cv2.merge([r,g,b])
plt.title(wrong_label_list[1])
plt.imshow(rgb_img)

plt.show()
In [9]:
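# Remove the two mislabeled images

wrong_label_list = ['./input/train/dog.11731.jpg', './input/train/dog.4334.jpg']

# Filter into a new list instead of calling remove() while iterating
n_before = len(train_images_path)
train_images_path = [file for file in train_images_path if file not in wrong_label_list]
print('deleted {} files.'.format(n_before - len(train_images_path)))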
deleted 2 files.

Supplementary Training Set Preprocessing

In [11]:
# Rename the supplementary images with a cat/dog prefix (e.g. cat_5597.jpg)

# trange_file_name is the renaming helper defined in utils
trange_file_name('./input/images/')
7393it [00:00, 78278.37it/s]
In [24]:
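# Find images in the supplementary set that OpenCV cannot read

TRAIN_DIR_SUP = './input/images/'
train_images_path_sup = [TRAIN_DIR_SUP + i for i in os.listdir(TRAIN_DIR_SUP)]

none_img_list = []

for file in tqdm(train_images_path_sup):
    image = cv2.imread(file)
    if image is None:  # cv2.imread returns None when a file cannot be decoded
        none_img_list.append(file)
        print(file)

# Remove the unreadable images (filter into a new list instead of remove() in the loop)
n_before = len(train_images_path_sup)
train_images_path_sup = [file for file in train_images_path_sup if file not in none_img_list]

print('deleted {} files.'.format(n_before - len(train_images_path_sup)))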
 44%|████▍     | 3286/7390 [01:11<01:28, 46.24it/s]
./input/images/cat_5597.jpg
 63%|██████▎   | 4651/7390 [01:40<00:58, 46.43it/s]
./input/images/cat_1399.jpg
 72%|███████▏  | 5293/7390 [01:53<00:45, 46.60it/s]
./input/images/cat_4197.jpg
 77%|███████▋  | 5699/7390 [02:02<00:36, 46.48it/s]
./input/images/cat_5760.jpg
 93%|█████████▎| 6888/7390 [02:28<00:10, 46.46it/s]
./input/images/cat_7252.jpg
 98%|█████████▊| 7245/7390 [02:35<00:03, 46.56it/s]
./input/images/cat_2717.jpg
100%|██████████| 7390/7390 [02:38<00:00, 46.61it/s]
deleted 6 files.

Merging the Kaggle and Supplementary Training Sets

In [14]:
# Merge the Kaggle training set with the supplementary training set
train_images = train_images_path + train_images_path_sup


# Shuffle the combined training set

random.seed(2018)
random.shuffle(train_images)


ROWS = 350
COLS = 350
CHANNELS = 3

def read_image(file_path):
    '''
    Read an image (BGR, as loaded by OpenCV) and resize it to ROWS x COLS.
    '''
    img = cv2.imread(file_path, cv2.IMREAD_COLOR)
    # cv2.resize takes dsize as (width, height)
    return cv2.resize(img, (COLS, ROWS), interpolation=cv2.INTER_CUBIC)

def prep_train_data(images_path):
    '''
    Preprocess the training data: returns uint8 image features and one-hot labels.
    '''

    # One-hot labels: column 0 = dog, column 1 = cat (taken from the file name)
    labels = np.zeros((len(images_path), 2), dtype=np.uint8)

    for i, path in enumerate(images_path):
        if 'dog' in os.path.basename(path):
            labels[i][0] = 1
        else:
            labels[i][1] = 1

    count = len(images_path)
    features = np.ndarray((count, ROWS, COLS, CHANNELS), dtype=np.uint8)

    for i, image_file in tqdm(enumerate(images_path)):
        features[i] = read_image(image_file)

    return features, labels

features, labels = prep_train_data(train_images)

print("features shape: {}".format(features.shape))
print("labels shape: {}".format(labels.shape))
32333it [07:30, 71.74it/s]
features shape: (32333, 350, 350, 3)
labels shape: (32333, 2)
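The following back-of-the-envelope check shows how much memory the feature array takes. Storing it as uint8 keeps it at about 11 GiB. A float32 copy, for example one made for model preprocessing, would need four times as much. The array's size also explains why it cannot be written with pickle protocol 3 (the default before Python 3.8), which cannot serialize objects larger than 4 GiB. The helper name array_nbytes is illustrative.

```python
import numpy as np

def array_nbytes(shape, dtype):
    """Bytes needed for a dense NumPy array of the given shape and dtype."""
    return int(np.prod(shape, dtype=np.int64)) * np.dtype(dtype).itemsize

shape = (32333, 350, 350, 3)  # features shape printed above
uint8_bytes = array_nbytes(shape, np.uint8)
float32_bytes = array_nbytes(shape, np.float32)
print('uint8:   {:.2f} GiB'.format(uint8_bytes / 2**30))
print('float32: {:.2f} GiB'.format(float32_bytes / 2**30))
```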

In [ ]:
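# Save the training data
# protocol=4 (Python 3.4+) is needed to pickle objects larger than 4 GiB
with open('train_data.p', 'wb') as f:
    pickle.dump((features, labels), f, protocol=4)
print('save data done!')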
In [ ]:
# Load the training data (same file name as used when saving)
with open('train_data.p', mode='rb') as f:
    features, labels = pickle.load(f)
print('load data done!')
In [18]:
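# Sanity check after loading: show one random image with its label
random.seed(2018)
image_index = random.choice(range(len(features)))
image_file = features[image_index]
plt.imshow(image_file[..., ::-1])  # stored in OpenCV's BGR order; reverse channels for display
plt.title('label [dog, cat]: {}'.format(labels[image_index]))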
plt.show()
In [19]:
# ResNet50 with ELU

# Other optimizers tried: RMSprop(lr=1e-4), SGD(0.001, momentum=0.9, nesterov=True)
optimizer = Nadam(lr=0.0005)
objective = 'binary_crossentropy'

# ImageNet-pretrained ResNet50 without its classifier, used as a frozen feature extractor
base_model = ResNet50(include_top=False, weights='imagenet')

for layer in base_model.layers:
    layer.trainable = False

# New classification head
head = base_model.output
batchnormed_1 = BatchNormalization(axis=3)(head)
avgpooled = GlobalAveragePooling2D()(batchnormed_1)
dense = Dense(256)(avgpooled)
batchnormed_2 = BatchNormalization()(dense)
act = ELU()(batchnormed_2)
dropout = Dropout(0.2)(act)

# (A second Dense(256) + BatchNormalization + ELU + Dropout(0.2) block was also tried.)

# Two sigmoid outputs matching the one-hot [dog, cat] labels
output = Dense(2, activation='sigmoid')(dropout)
model = Model(base_model.input, output)

model.compile(optimizer=optimizer, loss=objective, metrics=['accuracy'])
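GlobalAveragePooling2D, the layer that finally got training to move in the background section, collapses each feature map to a single number per channel. The head therefore receives a fixed-length vector whatever the spatial size of the input, and the base model here is built without a fixed input shape. Below is a NumPy sketch of that computation for channels-last input. It shows what the layer computes, not the Keras implementation.

```python
import numpy as np

def global_average_pool_2d(x):
    """Channels-last global average pooling: (batch, height, width, channels) -> (batch, channels)."""
    return x.mean(axis=(1, 2))

# Toy feature map: batch of 1, 2x2 spatial grid, 3 channels
x = np.arange(12, dtype=np.float32).reshape(1, 2, 2, 3)
pooled = global_average_pool_2d(x)
print(pooled)  # per-channel means over the 2x2 grid
```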
In [21]:
# Visualize the model architecture with model_to_dot
SVG(model_to_dot(model).create(prog='dot', format='svg'))
Out[21]:
[Figure: model graph rendered by model_to_dot. The ResNet50 backbone (input_1, conv1_pad, conv1, bn_conv1, max_pooling2d_1, then residual blocks res2a through res5c, ending in activation_49) feeds the new head: batch_normalization_1 → global_average_pooling2d_1 → dense_1 → batch_normalization_2 → elu_1 → dropout_1 → dense_2.]
In [22]:
# Show the model's layers, output shapes and parameter counts
model.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, None, None, 3 0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, None, None, 3 0           input_1[0][0]                    
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, None, None, 6 9472        conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, None, None, 6 256         conv1[0][0]                      
__________________________________________________________________________________________________
activation_1 (Activation)       (None, None, None, 6 0           bn_conv1[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, None, None, 6 0           activation_1[0][0]               
__________________________________________________________________________________________________
res2a_branch2a (Conv2D)         (None, None, None, 6 4160        max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, None, None, 6 256         res2a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_2 (Activation)       (None, None, None, 6 0           bn2a_branch2a[0][0]              
__________________________________________________________________________________________________
res2a_branch2b (Conv2D)         (None, None, None, 6 36928       activation_2[0][0]               
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, None, None, 6 256         res2a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_3 (Activation)       (None, None, None, 6 0           bn2a_branch2b[0][0]              
__________________________________________________________________________________________________
res2a_branch2c (Conv2D)         (None, None, None, 2 16640       activation_3[0][0]               
__________________________________________________________________________________________________
res2a_branch1 (Conv2D)          (None, None, None, 2 16640       max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, None, None, 2 1024        res2a_branch2c[0][0]             
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, None, None, 2 1024        res2a_branch1[0][0]              
__________________________________________________________________________________________________
add_1 (Add)                     (None, None, None, 2 0           bn2a_branch2c[0][0]              
                                                                 bn2a_branch1[0][0]               
__________________________________________________________________________________________________
activation_4 (Activation)       (None, None, None, 2 0           add_1[0][0]                      
__________________________________________________________________________________________________
res2b_branch2a (Conv2D)         (None, None, None, 6 16448       activation_4[0][0]               
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, None, None, 6 256         res2b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_5 (Activation)       (None, None, None, 6 0           bn2b_branch2a[0][0]              
__________________________________________________________________________________________________
res2b_branch2b (Conv2D)         (None, None, None, 6 36928       activation_5[0][0]               
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, None, None, 6 256         res2b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_6 (Activation)       (None, None, None, 6 0           bn2b_branch2b[0][0]              
__________________________________________________________________________________________________
res2b_branch2c (Conv2D)         (None, None, None, 2 16640       activation_6[0][0]               
__________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizati (None, None, None, 2 1024        res2b_branch2c[0][0]             
__________________________________________________________________________________________________
add_2 (Add)                     (None, None, None, 2 0           bn2b_branch2c[0][0]              
                                                                 activation_4[0][0]               
__________________________________________________________________________________________________
activation_7 (Activation)       (None, None, None, 2 0           add_2[0][0]                      
__________________________________________________________________________________________________
res2c_branch2a (Conv2D)         (None, None, None, 6 16448       activation_7[0][0]               
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, None, None, 6 256         res2c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_8 (Activation)       (None, None, None, 6 0           bn2c_branch2a[0][0]              
__________________________________________________________________________________________________
res2c_branch2b (Conv2D)         (None, None, None, 6 36928       activation_8[0][0]               
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, None, None, 6 256         res2c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_9 (Activation)       (None, None, None, 6 0           bn2c_branch2b[0][0]              
__________________________________________________________________________________________________
res2c_branch2c (Conv2D)         (None, None, None, 2 16640       activation_9[0][0]               
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, None, None, 2 1024        res2c_branch2c[0][0]             
__________________________________________________________________________________________________
add_3 (Add)                     (None, None, None, 2 0           bn2c_branch2c[0][0]              
                                                                 activation_7[0][0]               
__________________________________________________________________________________________________
activation_10 (Activation)      (None, None, None, 2 0           add_3[0][0]                      
__________________________________________________________________________________________________
res3a_branch2a (Conv2D)         (None, None, None, 1 32896       activation_10[0][0]              
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, None, None, 1 512         res3a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_11 (Activation)      (None, None, None, 1 0           bn3a_branch2a[0][0]              
__________________________________________________________________________________________________
res3a_branch2b (Conv2D)         (None, None, None, 1 147584      activation_11[0][0]              
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, None, None, 1 512         res3a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_12 (Activation)      (None, None, None, 1 0           bn3a_branch2b[0][0]              
__________________________________________________________________________________________________
res3a_branch2c (Conv2D)         (None, None, None, 5 66048       activation_12[0][0]              
__________________________________________________________________________________________________
res3a_branch1 (Conv2D)          (None, None, None, 5 131584      activation_10[0][0]              
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, None, None, 5 2048        res3a_branch2c[0][0]             
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, None, None, 5 2048        res3a_branch1[0][0]              
__________________________________________________________________________________________________
add_4 (Add)                     (None, None, None, 5 0           bn3a_branch2c[0][0]              
                                                                 bn3a_branch1[0][0]               
__________________________________________________________________________________________________
activation_13 (Activation)      (None, None, None, 5 0           add_4[0][0]                      
__________________________________________________________________________________________________
res3b_branch2a (Conv2D)         (None, None, None, 1 65664       activation_13[0][0]              
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, None, None, 1 512         res3b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_14 (Activation)      (None, None, None, 1 0           bn3b_branch2a[0][0]              
__________________________________________________________________________________________________
res3b_branch2b (Conv2D)         (None, None, None, 1 147584      activation_14[0][0]              
__________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizati (None, None, None, 1 512         res3b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_15 (Activation)      (None, None, None, 1 0           bn3b_branch2b[0][0]              
__________________________________________________________________________________________________
res3b_branch2c (Conv2D)         (None, None, None, 5 66048       activation_15[0][0]              
__________________________________________________________________________________________________
bn3b_branch2c (BatchNormalizati (None, None, None, 5 2048        res3b_branch2c[0][0]             
__________________________________________________________________________________________________
add_5 (Add)                     (None, None, None, 5 0           bn3b_branch2c[0][0]              
                                                                 activation_13[0][0]              
__________________________________________________________________________________________________
activation_16 (Activation)      (None, None, None, 5 0           add_5[0][0]                      
__________________________________________________________________________________________________
res3c_branch2a (Conv2D)         (None, None, None, 1 65664       activation_16[0][0]              
__________________________________________________________________________________________________
bn3c_branch2a (BatchNormalizati (None, None, None, 1 512         res3c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_17 (Activation)      (None, None, None, 1 0           bn3c_branch2a[0][0]              
__________________________________________________________________________________________________
res3c_branch2b (Conv2D)         (None, None, None, 1 147584      activation_17[0][0]              
__________________________________________________________________________________________________
bn3c_branch2b (BatchNormalizati (None, None, None, 1 512         res3c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_18 (Activation)      (None, None, None, 1 0           bn3c_branch2b[0][0]              
__________________________________________________________________________________________________
res3c_branch2c (Conv2D)         (None, None, None, 5 66048       activation_18[0][0]              
__________________________________________________________________________________________________
bn3c_branch2c (BatchNormalizati (None, None, None, 5 2048        res3c_branch2c[0][0]             
__________________________________________________________________________________________________
add_6 (Add)                     (None, None, None, 5 0           bn3c_branch2c[0][0]              
                                                                 activation_16[0][0]              
__________________________________________________________________________________________________
activation_19 (Activation)      (None, None, None, 5 0           add_6[0][0]                      
__________________________________________________________________________________________________
res3d_branch2a (Conv2D)         (None, None, None, 1 65664       activation_19[0][0]              
__________________________________________________________________________________________________
bn3d_branch2a (BatchNormalizati (None, None, None, 1 512         res3d_branch2a[0][0]             
__________________________________________________________________________________________________
activation_20 (Activation)      (None, None, None, 1 0           bn3d_branch2a[0][0]              
__________________________________________________________________________________________________
res3d_branch2b (Conv2D)         (None, None, None, 1 147584      activation_20[0][0]              
__________________________________________________________________________________________________
bn3d_branch2b (BatchNormalizati (None, None, None, 1 512         res3d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_21 (Activation)      (None, None, None, 1 0           bn3d_branch2b[0][0]              
__________________________________________________________________________________________________
res3d_branch2c (Conv2D)         (None, None, None, 5 66048       activation_21[0][0]              
__________________________________________________________________________________________________
bn3d_branch2c (BatchNormalizati (None, None, None, 5 2048        res3d_branch2c[0][0]             
__________________________________________________________________________________________________
add_7 (Add)                     (None, None, None, 5 0           bn3d_branch2c[0][0]              
                                                                 activation_19[0][0]              
__________________________________________________________________________________________________
activation_22 (Activation)      (None, None, None, 5 0           add_7[0][0]                      
__________________________________________________________________________________________________
res4a_branch2a (Conv2D)         (None, None, None, 2 131328      activation_22[0][0]              
__________________________________________________________________________________________________
bn4a_branch2a (BatchNormalizati (None, None, None, 2 1024        res4a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_23 (Activation)      (None, None, None, 2 0           bn4a_branch2a[0][0]              
__________________________________________________________________________________________________
res4a_branch2b (Conv2D)         (None, None, None, 2 590080      activation_23[0][0]              
__________________________________________________________________________________________________
bn4a_branch2b (BatchNormalizati (None, None, None, 2 1024        res4a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_24 (Activation)      (None, None, None, 2 0           bn4a_branch2b[0][0]              
__________________________________________________________________________________________________
res4a_branch2c (Conv2D)         (None, None, None, 1 263168      activation_24[0][0]              
__________________________________________________________________________________________________
res4a_branch1 (Conv2D)          (None, None, None, 1 525312      activation_22[0][0]              
__________________________________________________________________________________________________
bn4a_branch2c (BatchNormalizati (None, None, None, 1 4096        res4a_branch2c[0][0]             
__________________________________________________________________________________________________
bn4a_branch1 (BatchNormalizatio (None, None, None, 1 4096        res4a_branch1[0][0]              
__________________________________________________________________________________________________
add_8 (Add)                     (None, None, None, 1 0           bn4a_branch2c[0][0]              
                                                                 bn4a_branch1[0][0]               
__________________________________________________________________________________________________
activation_25 (Activation)      (None, None, None, 1 0           add_8[0][0]                      
__________________________________________________________________________________________________
res4b_branch2a (Conv2D)         (None, None, None, 2 262400      activation_25[0][0]              
__________________________________________________________________________________________________
bn4b_branch2a (BatchNormalizati (None, None, None, 2 1024        res4b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_26 (Activation)      (None, None, None, 2 0           bn4b_branch2a[0][0]              
__________________________________________________________________________________________________
res4b_branch2b (Conv2D)         (None, None, None, 2 590080      activation_26[0][0]              
__________________________________________________________________________________________________
bn4b_branch2b (BatchNormalizati (None, None, None, 2 1024        res4b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_27 (Activation)      (None, None, None, 2 0           bn4b_branch2b[0][0]              
__________________________________________________________________________________________________
res4b_branch2c (Conv2D)         (None, None, None, 1 263168      activation_27[0][0]              
__________________________________________________________________________________________________
bn4b_branch2c (BatchNormalizati (None, None, None, 1 4096        res4b_branch2c[0][0]             
__________________________________________________________________________________________________
add_9 (Add)                     (None, None, None, 1 0           bn4b_branch2c[0][0]              
                                                                 activation_25[0][0]              
__________________________________________________________________________________________________
activation_28 (Activation)      (None, None, None, 1 0           add_9[0][0]                      
__________________________________________________________________________________________________
res4c_branch2a (Conv2D)         (None, None, None, 2 262400      activation_28[0][0]              
__________________________________________________________________________________________________
bn4c_branch2a (BatchNormalizati (None, None, None, 2 1024        res4c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_29 (Activation)      (None, None, None, 2 0           bn4c_branch2a[0][0]              
__________________________________________________________________________________________________
res4c_branch2b (Conv2D)         (None, None, None, 2 590080      activation_29[0][0]              
__________________________________________________________________________________________________
bn4c_branch2b (BatchNormalizati (None, None, None, 2 1024        res4c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_30 (Activation)      (None, None, None, 2 0           bn4c_branch2b[0][0]              
__________________________________________________________________________________________________
res4c_branch2c (Conv2D)         (None, None, None, 1 263168      activation_30[0][0]              
__________________________________________________________________________________________________
bn4c_branch2c (BatchNormalizati (None, None, None, 1 4096        res4c_branch2c[0][0]             
__________________________________________________________________________________________________
add_10 (Add)                    (None, None, None, 1 0           bn4c_branch2c[0][0]              
                                                                 activation_28[0][0]              
__________________________________________________________________________________________________
activation_31 (Activation)      (None, None, None, 1 0           add_10[0][0]                     
__________________________________________________________________________________________________
res4d_branch2a (Conv2D)         (None, None, None, 2 262400      activation_31[0][0]              
__________________________________________________________________________________________________
bn4d_branch2a (BatchNormalizati (None, None, None, 2 1024        res4d_branch2a[0][0]             
__________________________________________________________________________________________________
activation_32 (Activation)      (None, None, None, 2 0           bn4d_branch2a[0][0]              
__________________________________________________________________________________________________
res4d_branch2b (Conv2D)         (None, None, None, 2 590080      activation_32[0][0]              
__________________________________________________________________________________________________
bn4d_branch2b (BatchNormalizati (None, None, None, 2 1024        res4d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_33 (Activation)      (None, None, None, 2 0           bn4d_branch2b[0][0]              
__________________________________________________________________________________________________
res4d_branch2c (Conv2D)         (None, None, None, 1 263168      activation_33[0][0]              
__________________________________________________________________________________________________
bn4d_branch2c (BatchNormalizati (None, None, None, 1 4096        res4d_branch2c[0][0]             
__________________________________________________________________________________________________
add_11 (Add)                    (None, None, None, 1 0           bn4d_branch2c[0][0]              
                                                                 activation_31[0][0]              
__________________________________________________________________________________________________
activation_34 (Activation)      (None, None, None, 1 0           add_11[0][0]                     
__________________________________________________________________________________________________
res4e_branch2a (Conv2D)         (None, None, None, 2 262400      activation_34[0][0]              
__________________________________________________________________________________________________
bn4e_branch2a (BatchNormalizati (None, None, None, 2 1024        res4e_branch2a[0][0]             
__________________________________________________________________________________________________
activation_35 (Activation)      (None, None, None, 2 0           bn4e_branch2a[0][0]              
__________________________________________________________________________________________________
res4e_branch2b (Conv2D)         (None, None, None, 2 590080      activation_35[0][0]              
__________________________________________________________________________________________________
bn4e_branch2b (BatchNormalizati (None, None, None, 2 1024        res4e_branch2b[0][0]             
__________________________________________________________________________________________________
activation_36 (Activation)      (None, None, None, 2 0           bn4e_branch2b[0][0]              
__________________________________________________________________________________________________
res4e_branch2c (Conv2D)         (None, None, None, 1 263168      activation_36[0][0]              
__________________________________________________________________________________________________
bn4e_branch2c (BatchNormalizati (None, None, None, 1 4096        res4e_branch2c[0][0]             
__________________________________________________________________________________________________
add_12 (Add)                    (None, None, None, 1 0           bn4e_branch2c[0][0]              
                                                                 activation_34[0][0]              
__________________________________________________________________________________________________
activation_37 (Activation)      (None, None, None, 1 0           add_12[0][0]                     
__________________________________________________________________________________________________
res4f_branch2a (Conv2D)         (None, None, None, 2 262400      activation_37[0][0]              
__________________________________________________________________________________________________
bn4f_branch2a (BatchNormalizati (None, None, None, 2 1024        res4f_branch2a[0][0]             
__________________________________________________________________________________________________
activation_38 (Activation)      (None, None, None, 2 0           bn4f_branch2a[0][0]              
__________________________________________________________________________________________________
res4f_branch2b (Conv2D)         (None, None, None, 2 590080      activation_38[0][0]              
__________________________________________________________________________________________________
bn4f_branch2b (BatchNormalizati (None, None, None, 2 1024        res4f_branch2b[0][0]             
__________________________________________________________________________________________________
activation_39 (Activation)      (None, None, None, 2 0           bn4f_branch2b[0][0]              
__________________________________________________________________________________________________
res4f_branch2c (Conv2D)         (None, None, None, 1 263168      activation_39[0][0]              
__________________________________________________________________________________________________
bn4f_branch2c (BatchNormalizati (None, None, None, 1 4096        res4f_branch2c[0][0]             
__________________________________________________________________________________________________
add_13 (Add)                    (None, None, None, 1 0           bn4f_branch2c[0][0]              
                                                                 activation_37[0][0]              
__________________________________________________________________________________________________
activation_40 (Activation)      (None, None, None, 1 0           add_13[0][0]                     
__________________________________________________________________________________________________
res5a_branch2a (Conv2D)         (None, None, None, 5 524800      activation_40[0][0]              
__________________________________________________________________________________________________
bn5a_branch2a (BatchNormalizati (None, None, None, 5 2048        res5a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_41 (Activation)      (None, None, None, 5 0           bn5a_branch2a[0][0]              
__________________________________________________________________________________________________
res5a_branch2b (Conv2D)         (None, None, None, 5 2359808     activation_41[0][0]              
__________________________________________________________________________________________________
bn5a_branch2b (BatchNormalizati (None, None, None, 5 2048        res5a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_42 (Activation)      (None, None, None, 5 0           bn5a_branch2b[0][0]              
__________________________________________________________________________________________________
res5a_branch2c (Conv2D)         (None, None, None, 2 1050624     activation_42[0][0]              
__________________________________________________________________________________________________
res5a_branch1 (Conv2D)          (None, None, None, 2 2099200     activation_40[0][0]              
__________________________________________________________________________________________________
bn5a_branch2c (BatchNormalizati (None, None, None, 2 8192        res5a_branch2c[0][0]             
__________________________________________________________________________________________________
bn5a_branch1 (BatchNormalizatio (None, None, None, 2 8192        res5a_branch1[0][0]              
__________________________________________________________________________________________________
add_14 (Add)                    (None, None, None, 2 0           bn5a_branch2c[0][0]              
                                                                 bn5a_branch1[0][0]               
__________________________________________________________________________________________________
activation_43 (Activation)      (None, None, None, 2 0           add_14[0][0]                     
__________________________________________________________________________________________________
res5b_branch2a (Conv2D)         (None, None, None, 5 1049088     activation_43[0][0]              
__________________________________________________________________________________________________
bn5b_branch2a (BatchNormalizati (None, None, None, 5 2048        res5b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_44 (Activation)      (None, None, None, 5 0           bn5b_branch2a[0][0]              
__________________________________________________________________________________________________
res5b_branch2b (Conv2D)         (None, None, None, 5 2359808     activation_44[0][0]              
__________________________________________________________________________________________________
bn5b_branch2b (BatchNormalizati (None, None, None, 5 2048        res5b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_45 (Activation)      (None, None, None, 5 0           bn5b_branch2b[0][0]              
__________________________________________________________________________________________________
res5b_branch2c (Conv2D)         (None, None, None, 2 1050624     activation_45[0][0]              
__________________________________________________________________________________________________
bn5b_branch2c (BatchNormalizati (None, None, None, 2 8192        res5b_branch2c[0][0]             
__________________________________________________________________________________________________
add_15 (Add)                    (None, None, None, 2 0           bn5b_branch2c[0][0]              
                                                                 activation_43[0][0]              
__________________________________________________________________________________________________
activation_46 (Activation)      (None, None, None, 2 0           add_15[0][0]                     
__________________________________________________________________________________________________
res5c_branch2a (Conv2D)         (None, None, None, 5 1049088     activation_46[0][0]              
__________________________________________________________________________________________________
bn5c_branch2a (BatchNormalizati (None, None, None, 5 2048        res5c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_47 (Activation)      (None, None, None, 5 0           bn5c_branch2a[0][0]              
__________________________________________________________________________________________________
res5c_branch2b (Conv2D)         (None, None, None, 5 2359808     activation_47[0][0]              
__________________________________________________________________________________________________
bn5c_branch2b (BatchNormalizati (None, None, None, 5 2048        res5c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_48 (Activation)      (None, None, None, 5 0           bn5c_branch2b[0][0]              
__________________________________________________________________________________________________
res5c_branch2c (Conv2D)         (None, None, None, 2 1050624     activation_48[0][0]              
__________________________________________________________________________________________________
bn5c_branch2c (BatchNormalizati (None, None, None, 2 8192        res5c_branch2c[0][0]             
__________________________________________________________________________________________________
add_16 (Add)                    (None, None, None, 2 0           bn5c_branch2c[0][0]              
                                                                 activation_46[0][0]              
__________________________________________________________________________________________________
activation_49 (Activation)      (None, None, None, 2 0           add_16[0][0]                     
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, None, None, 2 8192        activation_49[0][0]              
__________________________________________________________________________________________________
global_average_pooling2d_1 (Glo (None, 2048)         0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 256)          524544      global_average_pooling2d_1[0][0] 
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 256)          1024        dense_1[0][0]                    
__________________________________________________________________________________________________
elu_1 (ELU)                     (None, 256)          0           batch_normalization_2[0][0]      
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 256)          0           elu_1[0][0]                      
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 2)            514         dropout_1[0][0]                  
==================================================================================================
Total params: 24,121,986
Trainable params: 529,666
Non-trainable params: 23,592,320
__________________________________________________________________________________________________
In [23]:
# Model training

random.seed(2018)

nb_epoch = 20
batch_size = 128

# Save the model that performs best on the validation set during training
val_checkpoint = ModelCheckpoint('resnet_bestval_{val_loss:.4f}.h5', monitor='val_loss', verbose=1, save_best_only=True, mode='min')
# cur_checkpoint = ModelCheckpoint('current.h5')

# Halve the learning rate (factor=0.5) when val_loss has not improved for 2 epochs
lrSchduler = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, cooldown=1, verbose=1)

# Custom callback: record loss and val_loss at the end of each epoch
class LossHistory(Callback):
    def on_train_begin(self, logs=None):
        self.losses = []
        self.val_losses = []
        
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))

# To cut training time and limit overfitting, stop training once val_loss has not improved for 5 epochs
early_stopping = EarlyStopping(monitor='val_loss', patience=5, verbose=1, mode='min')


# The training data is ordered cats first, then dogs, so shuffle=True shuffles the training samples every epoch.
# Note: Keras takes the validation_split samples from the end of the arrays *before* shuffling,
# so features/labels must already be in random order for the validation set to contain both classes.
def run_catdog():
    
    history = LossHistory()
    model.fit(features, labels, batch_size=batch_size, epochs=nb_epoch,
              validation_split=0.2, verbose=1, shuffle=True, 
              callbacks=[history, early_stopping, val_checkpoint, lrSchduler])
    
    return history

history = run_catdog()
Train on 25866 samples, validate on 6467 samples
Epoch 1/20
25866/25866 [==============================] - 696s 27ms/step - loss: 0.0550 - acc: 0.9804 - val_loss: 0.0494 - val_acc: 0.9821

Epoch 00001: val_loss improved from inf to 0.04945, saving model to resnet_bestval_0.0494.h5
Epoch 2/20
25866/25866 [==============================] - 628s 24ms/step - loss: 0.0252 - acc: 0.9912 - val_loss: 0.0371 - val_acc: 0.9860

Epoch 00002: val_loss improved from 0.04945 to 0.03712, saving model to resnet_bestval_0.0371.h5
Epoch 3/20
25866/25866 [==============================] - 628s 24ms/step - loss: 0.0176 - acc: 0.9939 - val_loss: 0.0398 - val_acc: 0.9869

Epoch 00003: val_loss did not improve from 0.03712
Epoch 4/20
25866/25866 [==============================] - 628s 24ms/step - loss: 0.0179 - acc: 0.9933 - val_loss: 0.0374 - val_acc: 0.9871

Epoch 00004: val_loss did not improve from 0.03712

Epoch 00004: ReduceLROnPlateau reducing learning rate to 0.0002500000118743628.
Epoch 5/20
25866/25866 [==============================] - 628s 24ms/step - loss: 0.0101 - acc: 0.9966 - val_loss: 0.0402 - val_acc: 0.9862

Epoch 00005: val_loss did not improve from 0.03712
Epoch 6/20
25866/25866 [==============================] - 628s 24ms/step - loss: 0.0086 - acc: 0.9976 - val_loss: 0.0401 - val_acc: 0.9862

Epoch 00006: val_loss did not improve from 0.03712

Epoch 00006: ReduceLROnPlateau reducing learning rate to 0.0001250000059371814.
Epoch 7/20
25866/25866 [==============================] - 628s 24ms/step - loss: 0.0067 - acc: 0.9983 - val_loss: 0.0381 - val_acc: 0.9876

Epoch 00007: val_loss did not improve from 0.03712
Epoch 00007: early stopping
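The two callback comments above can be checked against the log. The sketch below is a simplified pure-Python re-implementation of the per-epoch bookkeeping that, as far as I can tell from the Keras 2.x source, `ReduceLROnPlateau` (mode `min`, assumed default `min_delta=1e-4`) and `EarlyStopping` (`min_delta=0`) perform. `replay_callbacks` is a hypothetical helper, not a Keras API. Fed the rounded `val_loss` values printed above, it predicts learning-rate cuts after epochs 4 and 6 and a stop after epoch 7, which is what the log shows.

```python
def replay_callbacks(val_losses, lr, factor=0.5, lr_patience=2, cooldown=1,
                     lr_min_delta=1e-4, stop_patience=5):
    """Replay per-epoch val_loss values through a simplified model of
    ReduceLROnPlateau + EarlyStopping (both with mode='min').

    Returns (reductions, stop_epoch): reductions is a list of (epoch, new_lr)
    pairs; stop_epoch is the 1-indexed epoch at which training stops, or None.
    """
    lr_best, lr_wait, cooldown_counter = float('inf'), 0, 0
    es_best, es_wait = float('inf'), 0
    reductions = []
    for epoch, current in enumerate(val_losses, start=1):
        # ReduceLROnPlateau: an epoch spent in cooldown resets the wait counter
        if cooldown_counter > 0:
            cooldown_counter -= 1
            lr_wait = 0
        if current < lr_best - lr_min_delta:
            lr_best, lr_wait = current, 0
        elif cooldown_counter == 0:
            lr_wait += 1
            if lr_wait >= lr_patience:
                lr *= factor
                reductions.append((epoch, lr))
                cooldown_counter, lr_wait = cooldown, 0
        # EarlyStopping: any strict improvement resets the counter
        if current < es_best:
            es_best, es_wait = current, 0
        else:
            es_wait += 1
            if es_wait >= stop_patience:
                return reductions, epoch
    return reductions, None


# val_loss values copied (rounded to 4 decimals) from the single-model log above
resnet_val = [0.0494, 0.0371, 0.0398, 0.0374, 0.0402, 0.0401, 0.0381]
# val_loss values from the fusion-model log later in this notebook (Adadelta starts at lr=1.0)
fusion_val = [0.0407, 0.0161, 0.0135, 0.0158, 0.0099, 0.0130, 0.0084,
              0.0087, 0.0060, 0.0111, 0.0086, 0.0183, 0.0110, 0.0136]

print(replay_callbacks(resnet_val, lr=0.0005))
print(replay_callbacks(fusion_val, lr=1.0))
```

The same replay on the fusion model's values gives cuts after epochs 11 and 13 and a stop after epoch 14, again matching its log.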
In [ ]:
# Plot the training and validation loss per epoch

loss = history.losses
val_loss = history.val_losses


plt.figure(figsize=(12,6))

plt.plot(loss, 'blue', label='Training Loss')
plt.plot(val_loss, 'green', label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Loss Trend')
plt.legend()

plt.show()
In [38]:
# Load the model that performed best on the validation set and use it to predict the test set
model = load_model('bestval_0.0329.h5')
In [3]:
# Test-set preprocessing

ROWS = 350
COLS = 350
CHANNELS = 3

TEST_DIR = './test/test1/'

filenames = os.listdir(TEST_DIR)
filenames.sort(key=lambda x:int(x[:-4])) # sort by the numeric file name (1.jpg, 2.jpg, ...) rather than as strings

test_images_path =  [TEST_DIR+i for i in filenames]

test_images_path[:10]
Out[3]:
['./test/test1/1.jpg',
 './test/test1/2.jpg',
 './test/test1/3.jpg',
 './test/test1/4.jpg',
 './test/test1/5.jpg',
 './test/test1/6.jpg',
 './test/test1/7.jpg',
 './test/test1/8.jpg',
 './test/test1/9.jpg',
 './test/test1/10.jpg']
In [4]:
ROWS = 350
COLS = 350
CHANNELS = 3

def read_image(file_path):
    '''
    Read an image, convert it from OpenCV's BGR order to RGB, and resize it to ROWS x COLS
    '''
    img = cv2.imread(file_path, cv2.IMREAD_COLOR)
    rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # change channel order: BGR → RGB
    # cv2.resize expects dsize as (width, height)
    return cv2.resize(rgb_img, (COLS, ROWS), interpolation=cv2.INTER_CUBIC)

def prep_test_data(images_path):
    '''
    Preprocess the test set into a uint8 array of shape (count, ROWS, COLS, CHANNELS)
    '''
    count = len(images_path)
    test_images = np.ndarray((count, ROWS, COLS, CHANNELS), dtype=np.uint8)

    # use img rather than image so the keras.preprocessing `image` module is not shadowed
    for i, image_file in tqdm(enumerate(images_path)):
        test_images[i] = read_image(image_file)

    return test_images

test_images = prep_test_data(test_images_path)
12500it [02:07, 98.30it/s] 
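Two NumPy facts behind the preprocessing cell above, shown on a small stand-in array so that OpenCV is not needed: reversing the last axis (`img[..., ::-1]`) gives the same BGR → RGB swap as splitting the channels and merging them in reverse order (or `cv2.cvtColor(img, cv2.COLOR_BGR2RGB)`), and `ndarray.transpose` returns a new view instead of changing the array in place, so calling it without assigning the result leaves the array unchanged.

```python
import numpy as np

# stand-in for a decoded 4x6 BGR image (height 4, width 6, 3 channels)
bgr = np.arange(4 * 6 * 3, dtype=np.uint8).reshape(4, 6, 3)

# split the channels and stack them back in reverse order: BGR -> RGB
b, g, r = bgr[..., 0], bgr[..., 1], bgr[..., 2]
rgb_manual = np.stack([r, g, b], axis=-1)

# the same swap as a single slice over the channel axis
rgb_slice = bgr[..., ::-1]
same = np.array_equal(rgb_manual, rgb_slice)

# transpose returns a new view and leaves the original array untouched
img = rgb_slice.copy()
img.transpose(1, 0, 2)             # result discarded: img keeps shape (4, 6, 3)
swapped = img.transpose(1, 0, 2)   # the result must be assigned to be used: shape (6, 4, 3)
print(same, img.shape, swapped.shape)
```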
In [ ]:
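# Save the preprocessed test set
# protocol=4 supports objects larger than 4 GiB; this array is about 4.6 GB (12500 x 350 x 350 x 3 bytes).
# The with-block makes sure the file is flushed and closed.
with open('test_data_350.p', 'wb') as f:
    pickle.dump(test_images, f, protocol=4)
print('save data done!')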
In [14]:
# Load the test-set data

test_images = pickle.load(open('test_data_350.p', mode='rb'))
print('load data done!')
---------------------------------------------------------------------------
EOFError                                  Traceback (most recent call last)
<ipython-input-14-0b380f4cded0> in <module>()
----> 1 test_images = pickle.load(open('test_data_350.p', mode='rb'))
      2 print('load data done!')

EOFError: Ran out of input
In [85]:
# 利用训练好的模型对测试集数据做预测

predictions = model.predict(test_images,batch_size=128, verbose=1)
12500/12500 [==============================] - 256s 20ms/step
In [7]:
# Show 9 random predictions

np.random.seed(2018)  # np.random.randint below uses NumPy's RNG, which random.seed does not seed

plt.figure(figsize=(12,12))
plt.subplots_adjust(wspace=0.2, hspace=0.4)
for location in range(331, 340):
    plt.subplot(location)
    i = np.random.randint(low=0, high=12500)
    if predictions[i, 0] >= 0.5: 
        title = 'I am {:.3%} sure this is a Dog'.format(predictions[i, 0])
    else: 
        title = 'I am {:.3%} sure this is a Cat'.format(predictions[i, 1])
    plt.title(title)
    plt.imshow(test_images[i])
    
plt.show()
In [87]:
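# Kaggle scores submissions with log loss, which grows without bound as a wrong prediction approaches 0 or 1.
# Following how log_loss guards against infinity, clip the predictions to a bounded range; this noticeably improves the Kaggle score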

predictions = predictions.clip(min=0.005, max=0.995)
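A toy example with made-up numbers (100 dog images, one of which the model confidently calls a cat) shows why clipping to [0.005, 0.995] helps: an unclipped dog probability of 1e-6 for a true dog costs about -ln(1e-6) ≈ 13.8 on that image alone, while after clipping no single image can cost more than -ln(0.005) ≈ 5.3. The price is a small loss of -ln(0.995) ≈ 0.005 on every confidently correct image. `binary_log_loss` is a hypothetical helper written for this illustration, not Kaggle's scorer.

```python
import numpy as np

def binary_log_loss(y_true, p_dog):
    """Mean binary log loss; y_true is 1 for dog, 0 for cat."""
    y = np.asarray(y_true, dtype=np.float64)
    p = np.asarray(p_dog, dtype=np.float64)
    return float(-np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)))

# toy data: 100 dogs, 99 predicted with high confidence, 1 confidently wrong
y_true = np.ones(100)
p_dog = np.full(100, 0.99999)
p_dog[0] = 1e-6

raw_loss = binary_log_loss(y_true, p_dog)
clipped_loss = binary_log_loss(y_true, p_dog.clip(min=0.005, max=0.995))
print(raw_loss, clipped_loss)
```

Whether clipping pays off on a real test set depends on how often the model is confidently wrong; 0.005 is simply the bound used in this notebook.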
In [88]:
# Write the predictions to a CSV file in the format and id order required by Kaggle

with open('submission_0.0328.csv', 'w') as f:
    f.write('id,label\n')
    for i in tqdm(range(len(predictions))):
        # column 0 is treated as the dog probability, as in the display cell above
        f.write('{},{}\n'.format(i + 1, predictions[i, 0]))

print('file closed!')
100%|██████████| 12500/12500 [00:00<00:00, 376849.43it/s]
file closed!
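The submission file can also be written with the standard-library `csv` module, which takes care of delimiters and line endings. This is a sketch with a hypothetical helper `write_submission`; it assumes a 1-D sequence of dog probabilities already in test-id order and numbers the rows from 1, as the cell above does. An in-memory buffer stands in for the file.

```python
import csv
import io

def write_submission(f, dog_probs):
    """Write an 'id,label' header and one row per prediction to the open text file f; ids start at 1."""
    writer = csv.writer(f, lineterminator='\n')
    writer.writerow(['id', 'label'])
    for i, p in enumerate(dog_probs, start=1):
        writer.writerow([i, float(p)])

buf = io.StringIO()
write_submission(buf, [0.995, 0.005, 0.5])
text = buf.getvalue()
print(text)
```

For a real file, open it with `open('submission.csv', 'w', newline='')` (the mode the `csv` documentation recommends) and pass the handle in.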

Using multi-model fusion

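The basic idea of multi-model fusion: use ImageDataGenerator to read the images, use predict_generator to extract the feature vectors of each pretrained model, concatenate these feature vectors as the input of a new model, add a classifier on top, and train it directly.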
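A shape-level sketch of the fusion input, with random arrays standing in for the real feature vectors: ResNet50, Xception and InceptionV3 each yield a 2048-dimensional vector after global average pooling, so concatenating them along axis 1 gives 6144 features per image, and a single sigmoid unit on top has 6144 weights plus one bias, i.e. the 6,145 parameters that `model.summary()` reports for the fusion model. The logistic computation mirrors what `Dense(1, activation='sigmoid')` computes; the weights here are random, not trained.

```python
import numpy as np

rng = np.random.default_rng(2018)
n_images = 4
# random stand-ins for the pooled features of ResNet50, Xception and InceptionV3
features = [rng.standard_normal((n_images, 2048)).astype(np.float32) for _ in range(3)]

# fuse by concatenating along the feature axis
X = np.concatenate(features, axis=1)

# one sigmoid unit on top, as Dense(1, activation='sigmoid') would compute
w = rng.standard_normal(X.shape[1]).astype(np.float32) * 0.01
b = np.float32(0.0)
p_dog = 1.0 / (1.0 + np.exp(-(X @ w + b)))

n_params = w.size + 1  # weights + bias
print(X.shape, n_params)
```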

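ImageDataGenerator (through flow_from_directory) requires the training images to be stored in one subfolder per class, but neither the Kaggle training set nor the supplementary dataset is split into cat and dog folders, so the files have to be reorganized. To leave the single-model dataset intact, this project copies the files with shutil.copy instead of moving them.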
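A self-contained sketch of the copy step on a throwaway directory. `split_by_keyword` is a hypothetical helper; unlike `copyfile` below, it matches the keyword against the file's base name only, so a directory name that happens to contain "cat" or "dog" cannot cause false matches. The last line mirrors the behavior documented for `flow_from_directory`: class indices follow the alphanumeric order of the subfolder names, so `cats` becomes label 0 and `dogs` label 1, which is why a sigmoid output of at least 0.5 is read as "dog" in the fusion model.

```python
import os
import shutil
import tempfile

def split_by_keyword(src_files, dst_root, keywords=('cat', 'dog')):
    """Copy each file whose base name contains a keyword into dst_root/<keyword>s/."""
    counts = {}
    for kw in keywords:
        dst = os.path.join(dst_root, kw + 's')
        os.makedirs(dst, exist_ok=True)
        n = 0
        for f in src_files:
            if kw in os.path.basename(f):
                shutil.copy(f, dst)
                n += 1
        counts[kw] = n
    return counts

with tempfile.TemporaryDirectory() as tmp:
    # a source folder whose own name contains both 'dog' and 'cat'
    src = os.path.join(tmp, 'dogs_vs_cats_input')
    os.makedirs(src)
    files = []
    for name in ['cat.0.jpg', 'cat.1.jpg', 'dog.0.jpg']:
        path = os.path.join(src, name)
        open(path, 'wb').close()  # empty placeholder file
        files.append(path)

    train_root = os.path.join(tmp, 'train')
    counts = split_by_keyword(files, train_root)
    # label index per class folder, in alphanumeric order of the folder names
    class_indices = {c: i for i, c in enumerate(sorted(os.listdir(train_root)))}

print(counts, class_indices)
```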

In [15]:
# Create directories

import os
import shutil


TTRAIN_DIR = './input/train/'
train_images = os.listdir(TTRAIN_DIR)
train_cats = [file for file in train_images if 'cat' in file]
train_dogs = [file for file in train_images if 'dog' in file]

TEST_DIR = './input/test/'
test_images = os.listdir(TEST_DIR)

def mkdir(path):
    '''
    Create path (including missing parents) if it does not exist; return True if it was created
    '''
    if not os.path.exists(path):
        os.makedirs(path)
        print('{} created'.format(path))
        return True
    else:
        print('{} already exists'.format(path))
        return False

# Directories to create
mkpath_list = ['./mydata2/', './mydata2/train/', './mydata2/train/cats/', 
               './mydata2/train/dogs/', './mydata2/validation/', './mydata2/test1/', './mydata2/test1/test/']

for path in mkpath_list:
    mkdir(path)
./mydata2/ created
./mydata2/train/ created
./mydata2/train/cats/ created
./mydata2/train/dogs/ created
./mydata2/validation/ created
./mydata2/test1/ created
./mydata2/test1/test/ created
In [16]:
def copyfile(paths, keyword, new_path):
    '''
    Copy every file whose path contains keyword into new_path
    '''
    i = 0
    for file in paths:
        if keyword in file:
            shutil.copy(file, new_path)
            i = i + 1
    print('copied {} {} images to {}'.format(i, keyword, new_path))
In [17]:
def copyfileno(paths, new_path):
    '''
    Copy all files into new_path, without filtering
    '''
    i = 0
    for file in paths:
        shutil.copy(file, new_path)
        i = i + 1
    print('copied {} images to {}'.format(i, new_path))
In [18]:
TRAIN_DIR = './input/train/'
train_images = [TRAIN_DIR + file for file in os.listdir(TRAIN_DIR)]

TEST_DIR = './input/test/'
test_images = [TEST_DIR + file for file in os.listdir(TEST_DIR)]

TRAIN_SUP_DIR = './input/images/'
train_sup_images = [TRAIN_SUP_DIR + file for file in os.listdir(TRAIN_SUP_DIR)]

# Copy the Kaggle training set, the supplementary dataset and the Kaggle test set into the corresponding folders
copyfile(train_images, 'dog', './mydata2/train/dogs/')
copyfile(train_images, 'cat', './mydata2/train/cats/')
copyfile(train_sup_images, 'dog', './mydata2/train/dogs/')
copyfile(train_sup_images, 'cat', './mydata2/train/cats/')
copyfileno(test_images, './mydata2/test1/test/')
copied 12500 dog images to ./mydata2/train/dogs/
copied 12500 cat images to ./mydata2/train/cats/
copied 4990 dog images to ./mydata2/train/dogs/
copied 2400 cat images to ./mydata2/train/cats/
copied 12500 images to ./mydata2/test1/test/
In [25]:
# Remove the abnormal image files (49 images) from the class folders

i = 0
for file in os.listdir('./mydata2/train/dogs/'):
    if file in ab_img_list:
        os.remove('./mydata2/train/dogs/' + file)
        i = i + 1

for file in os.listdir('./mydata2/train/cats/'):
    if file in ab_img_list:
        os.remove('./mydata2/train/cats/' + file)
        i = i + 1

print('deleted {} files.'.format(i))
deleted 0 files.
In [26]:
# Remove the mislabeled images (2 images) from the class folders

i = 0
for file in os.listdir('./mydata2/train/dogs/'):
    if file in wrong_label_list:
        os.remove('./mydata2/train/dogs/' + file)
        i = i + 1

for file in os.listdir('./mydata2/train/cats/'):
    if file in wrong_label_list:
        os.remove('./mydata2/train/cats/' + file)
        i = i + 1

print('deleted {} files.'.format(i))
deleted 0 files.
In [27]:
# Remove the unreadable images of the supplementary dataset

# none_img_list holds full paths; keep only the file names
file_list = [os.path.basename(path) for path in none_img_list]

i = 0
for file in os.listdir('./mydata2/train/dogs/'):
    if file in file_list:
        os.remove('./mydata2/train/dogs/' + file)  # delete from the folder being scanned
        i = i + 1

for file in os.listdir('./mydata2/train/cats/'):
    if file in file_list:
        os.remove('./mydata2/train/cats/' + file)
        i = i + 1

print('deleted {} files.'.format(i))
deleted 6 files.
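The three deletion cells above repeat the same loop over both class folders. A sketch of a single helper (hypothetical name `remove_listed`) that normalizes every list entry to its base name and deletes it from whichever folder actually contains it, demonstrated on a temporary directory:

```python
import os
import tempfile

def remove_listed(folders, names):
    """Delete files whose base name appears in names from each folder; return how many were removed."""
    targets = {os.path.basename(n) for n in names}
    removed = 0
    for folder in folders:
        for file in os.listdir(folder):
            if file in targets:
                os.remove(os.path.join(folder, file))
                removed += 1
    return removed

with tempfile.TemporaryDirectory() as tmp:
    cats = os.path.join(tmp, 'cats')
    dogs = os.path.join(tmp, 'dogs')
    os.makedirs(cats)
    os.makedirs(dogs)
    for folder, name in [(cats, 'cat.1.jpg'), (cats, 'cat.2.jpg'), (dogs, 'dog.1.jpg')]:
        open(os.path.join(folder, name), 'wb').close()
    # list entries may be bare file names or full paths; unknown names are ignored
    removed = remove_listed([cats, dogs], ['cat.2.jpg', './input/images/dog.1.jpg', 'missing.jpg'])
    remaining = sorted(os.listdir(cats) + os.listdir(dogs))

print(removed, remaining)
```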
In [31]:
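from keras.layers import Lambda
from keras.models import Model
from keras.applications import ResNet50, InceptionV3, Xception, inception_v3, xception

def write_feature_vectors(MODEL, image_size, lambda_func=None):
    '''
    Extract ImageNet-pretrained features (top removed, followed by global average pooling)
    for the training and test images, and save them with the training labels to HDF5
    '''
    width = image_size[0]
    height = image_size[1]
    input_tensor = Input((height, width, 3))
    x = input_tensor
    if lambda_func:
        x = Lambda(lambda_func)(x)  # model-specific preprocessing, e.g. inception_v3.preprocess_input
    
    base_model = MODEL(input_tensor=x, weights='imagenet', include_top=False)
    model = Model(base_model.input, GlobalAveragePooling2D()(base_model.output))

    gen = ImageDataGenerator()
    # target_size is (height, width); shuffle=False keeps the file order so features line up with train_generator.classes
    train_generator = gen.flow_from_directory("./mydata2/train", (height, width), shuffle=False, batch_size=16)
    test_generator = gen.flow_from_directory("./mydata2/test1", (height, width), shuffle=False, batch_size=16, class_mode=None)

    train = model.predict_generator(train_generator, verbose=1)
    test = model.predict_generator(test_generator, verbose=1)
    with h5py.File('fv_{}_{}.h5'.format(MODEL.__name__, width), 'w') as h:
        h.create_dataset("train", data=train)
        h.create_dataset("test", data=test)
        h.create_dataset("label", data=train_generator.classes)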
In [32]:
# Use each model's default input size

# write_feature_vectors(VGG16, (224, 224))
write_feature_vectors(ResNet50, (224, 224))
write_feature_vectors(InceptionV3, (299, 299), inception_v3.preprocess_input)
write_feature_vectors(Xception, (299, 299), xception.preprocess_input)
In [33]:
# Use 350x350 input images

# write_feature_vectors(VGG16, (350, 350))
write_feature_vectors(ResNet50, (350, 350))
write_feature_vectors(InceptionV3, (350, 350), inception_v3.preprocess_input)
write_feature_vectors(Xception, (350, 350), xception.preprocess_input)
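Both image sizes produce feature vectors of the same length because global average pooling averages each channel over the spatial grid, so its output is always (batch, channels) whatever the height and width of the last feature map. A NumPy sketch of that reduction; the spatial sizes below are illustrative only.

```python
import numpy as np

def global_average_pool(feature_map):
    """Average over height and width: (batch, H, W, C) -> (batch, C)."""
    return feature_map.mean(axis=(1, 2))

rng = np.random.default_rng(0)
small = rng.random((2, 7, 7, 2048), dtype=np.float32)    # e.g. from a smaller input image
large = rng.random((2, 11, 11, 2048), dtype=np.float32)  # e.g. from a larger input image

pooled_small = global_average_pool(small)
pooled_large = global_average_pool(large)
print(pooled_small.shape, pooled_large.shape)
```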
In [2]:
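# Model fusion: concatenate the feature vectors of the three models

random.seed(2018)

X_train = []
X_test = []

# These names must match the files on disk; write_feature_vectors names its files fv_<Model>_<width>.h5
for filename in ["fv_ResNet50.h5", "fv_Xception.h5", "fv_InceptionV3.h5"]:
    with h5py.File(filename, 'r') as h:
        X_train.append(np.array(h['train']))
        X_test.append(np.array(h['test']))
        y_train = np.array(h['label'])  # same folder order in every file, so the labels are identical

X_train = np.concatenate(X_train, axis=1)
X_test = np.concatenate(X_test, axis=1)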
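According to the Keras documentation for `fit`, `validation_split` takes the validation samples from the end of the arrays before any shuffling. Because `flow_from_directory(..., shuffle=False)` yields the `cats` folder first and the `dogs` folder second, the last 20% of `X_train` here consists of dogs only. The sketch below shows one way to get a mixed validation set: shuffle features and labels with the same permutation before calling `fit`. `shuffle_together` is a hypothetical helper, and this step was not part of the training run logged below.

```python
import numpy as np

def shuffle_together(X, y, seed=2018):
    """Apply one random permutation to X and y so their rows stay aligned."""
    idx = np.random.RandomState(seed).permutation(len(y))
    return X[idx], y[idx]

# toy data ordered like the generator output: 6 cats (label 0), then 4 dogs (label 1)
y = np.array([0] * 6 + [1] * 4)
X = np.stack([y, np.arange(len(y))], axis=1).astype(np.float32)  # column 0 mirrors the label

validation_split = 0.2
split_at = int(len(y) * (1 - validation_split))  # Keras 2.x trains on rows [:split_at]
val_before = y[split_at:]                         # dogs only

X_s, y_s = shuffle_together(X, y)
val_after = y_s[split_at:]
print(val_before, val_after)
```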
In [46]:
# Build the model: put a classifier directly on top of the pretrained-model features
from keras.models import Model

input_tensor = Input(X_train.shape[1:])
x = Dropout(0.5)(input_tensor)
x = Dense(1, activation='sigmoid')(x)
model = Model(input_tensor, x)

model.compile(optimizer='adadelta',
              loss='binary_crossentropy',
              metrics=['accuracy'])
In [37]:
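# Visualize the model structure with model_to_dot
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot
SVG(model_to_dot(model).create(prog='dot', format='svg'))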
Out[37]:
[Model graph (SVG): input_1 (InputLayer) → dropout_1 (Dropout) → dense_1 (Dense)]
In [38]:
# Show the model parameters
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 6144)              0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 6144)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 6145      
=================================================================
Total params: 6,145
Trainable params: 6,145
Non-trainable params: 0
_________________________________________________________________
In [45]:
# Model training


nb_epoch = 20
batch_size = 128

# Save the model that performs best on the validation set during training
val_checkpoint = ModelCheckpoint('resnet_bestval_{val_loss:.4f}.h5', monitor='val_loss', verbose=1, save_best_only=True, mode='min')
# cur_checkpoint = ModelCheckpoint('current.h5')

# Halve the learning rate (factor=0.5) when val_loss has not improved for 2 epochs
lrSchduler = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, cooldown=1, verbose=1)

# Custom callback: record loss and val_loss at the end of each epoch
class LossHistory(Callback):
    def on_train_begin(self, logs=None):
        self.losses = []
        self.val_losses = []
        
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))

# To cut training time and limit overfitting, stop training once val_loss has not improved for 5 epochs
early_stopping = EarlyStopping(monitor='val_loss', patience=5, verbose=1, mode='min')


# shuffle=True shuffles the training samples every epoch; note that validation_split
# takes the last 20% of X_train *before* shuffling
def run_catdog():
    
    history = LossHistory()
    model.fit(X_train, y_train, batch_size=batch_size, epochs=nb_epoch,
              validation_split=0.2, verbose=1, shuffle=True, 
              callbacks=[history, early_stopping, val_checkpoint, lrSchduler])
    
    return history

history = run_catdog()
Train on 25867 samples, validate on 6467 samples
Epoch 1/20
25867/25867 [==============================] - 2s 77us/step - loss: 0.0640 - acc: 0.9781 - val_loss: 0.0407 - val_acc: 0.9934

Epoch 00001: val_loss improved from inf to 0.04074, saving model to resnet_bestval_0.0407.h5
Epoch 2/20
25867/25867 [==============================] - 1s 44us/step - loss: 0.0168 - acc: 0.9951 - val_loss: 0.0161 - val_acc: 0.9969

Epoch 00002: val_loss improved from 0.04074 to 0.01613, saving model to resnet_bestval_0.0161.h5
Epoch 3/20
25867/25867 [==============================] - 1s 43us/step - loss: 0.0145 - acc: 0.9956 - val_loss: 0.0135 - val_acc: 0.9964

Epoch 00003: val_loss improved from 0.01613 to 0.01349, saving model to resnet_bestval_0.0135.h5
Epoch 4/20
25867/25867 [==============================] - 1s 44us/step - loss: 0.0121 - acc: 0.9959 - val_loss: 0.0158 - val_acc: 0.9960

Epoch 00004: val_loss did not improve from 0.01349
Epoch 5/20
25867/25867 [==============================] - 1s 44us/step - loss: 0.0112 - acc: 0.9961 - val_loss: 0.0099 - val_acc: 0.9972

Epoch 00005: val_loss improved from 0.01349 to 0.00985, saving model to resnet_bestval_0.0099.h5
Epoch 6/20
25867/25867 [==============================] - 1s 44us/step - loss: 0.0098 - acc: 0.9967 - val_loss: 0.0130 - val_acc: 0.9966

Epoch 00006: val_loss did not improve from 0.00985
Epoch 7/20
25867/25867 [==============================] - 1s 44us/step - loss: 0.0090 - acc: 0.9969 - val_loss: 0.0084 - val_acc: 0.9978

Epoch 00007: val_loss improved from 0.00985 to 0.00839, saving model to resnet_bestval_0.0084.h5
Epoch 8/20
25867/25867 [==============================] - 1s 45us/step - loss: 0.0083 - acc: 0.9974 - val_loss: 0.0087 - val_acc: 0.9972

Epoch 00008: val_loss did not improve from 0.00839
Epoch 9/20
25867/25867 [==============================] - 1s 45us/step - loss: 0.0085 - acc: 0.9974 - val_loss: 0.0060 - val_acc: 0.9983

Epoch 00009: val_loss improved from 0.00839 to 0.00600, saving model to resnet_bestval_0.0060.h5
Epoch 10/20
25867/25867 [==============================] - 1s 44us/step - loss: 0.0082 - acc: 0.9973 - val_loss: 0.0111 - val_acc: 0.9969

Epoch 00010: val_loss did not improve from 0.00600
Epoch 11/20
25867/25867 [==============================] - 1s 44us/step - loss: 0.0078 - acc: 0.9976 - val_loss: 0.0086 - val_acc: 0.9975

Epoch 00011: val_loss did not improve from 0.00600

Epoch 00011: ReduceLROnPlateau reducing learning rate to 0.5.
Epoch 12/20
25867/25867 [==============================] - 1s 44us/step - loss: 0.0061 - acc: 0.9982 - val_loss: 0.0183 - val_acc: 0.9947

Epoch 00012: val_loss did not improve from 0.00600
Epoch 13/20
25867/25867 [==============================] - 1s 44us/step - loss: 0.0062 - acc: 0.9981 - val_loss: 0.0110 - val_acc: 0.9972

Epoch 00013: val_loss did not improve from 0.00600

Epoch 00013: ReduceLROnPlateau reducing learning rate to 0.25.
Epoch 14/20
25867/25867 [==============================] - 1s 45us/step - loss: 0.0058 - acc: 0.9982 - val_loss: 0.0136 - val_acc: 0.9963

Epoch 00014: val_loss did not improve from 0.00600
Epoch 00014: early stopping
In [76]:
# Plot the training and validation loss per epoch

loss = history.losses
val_loss = history.val_losses


plt.figure(figsize=(12,6))

plt.plot(loss, 'blue', label='Training Loss')
plt.plot(val_loss, 'green', label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Loss Trend')
plt.legend()

plt.show()
In [3]:
# Load the model that performed best on the validation set
model = load_model('mix_bestval_0.0069.h5')
In [4]:
# Predict the test set with the trained model

predictions = model.predict(X_test, verbose=1)
12500/12500 [==============================] - 1s 50us/step
In [8]:
predictions = pickle.load(open('predictions.p', mode='rb'))
In [5]:
predictions[:10]
Out[5]:
array([[9.9984729e-01],
       [6.4735364e-06],
       [1.6646410e-04],
       [9.9998653e-01],
       [9.9998856e-01],
       [4.4113831e-06],
       [6.1307808e-05],
       [9.9979514e-01],
       [9.9811792e-01],
       [4.7327758e-06]], dtype=float32)
In [12]:
# Get the order of the X_test files (the order in which the generator read the test images)
gen = ImageDataGenerator()
test_generator = gen.flow_from_directory("./test", (224, 224), shuffle=False, batch_size=16, class_mode=None)

file_list = []

for file in test_generator.filenames:
    # entries are '<subfolder>\\123.jpg' on Windows or '<subfolder>/123.jpg' elsewhere; keep the numeric id
    name = os.path.basename(file.replace('\\', '/'))
    file_list.append(int(os.path.splitext(name)[0]))

file_list = np.array(file_list)
Found 12500 images belonging to 1 classes.
In [20]:
# Show 9 random predictions

np.random.seed(2018)  # np.random.randint below uses NumPy's RNG, which random.seed does not seed
test_images_path = './mydata2/test1/'

plt.figure(figsize=(12,12))
plt.subplots_adjust(wspace=0.2, hspace=0.4)
for location in range(331, 340):
    plt.subplot(location)
    i = np.random.randint(low=0, high=12500)
    # predictions are still in generator order here, matching test_generator.filenames
    p_dog = float(np.ravel(predictions)[i])
    if p_dog >= 0.5: 
        title = 'I am {:.3%} sure this is a Dog'.format(p_dog)
    else: 
        title = 'I am {:.3%} sure this is a Cat'.format(1 - p_dog)
    plt.title(title)
    file = test_generator.filenames[i]
    img = cv2.imread(test_images_path + file)
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))  # change channel order: BGR → RGB
    
plt.show()

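The test images are not read in numerical order of their file names: the generator lists the files sorted as strings, so 10.jpg comes before 2.jpg. The predictions therefore have to be reordered, using the permutation that sorts the numeric file ids of X_test.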
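A toy version of the reordering with five file names: sorting the names as strings puts 10.jpg and 100.jpg before 2.jpg, `np.argsort` on the numeric ids gives the permutation back to numeric order, and indexing the predictions with that permutation puts them in Kaggle's id order. The file names and prediction values are made up; the `test\\` prefix imitates Windows-style paths.

```python
import os
import numpy as np

# file names sorted as strings, the way the generator lists them (made-up examples)
filenames = sorted(['test\\1.jpg', 'test\\2.jpg', 'test\\3.jpg', 'test\\10.jpg', 'test\\100.jpg'])

# numeric id of each file, accepting either '\\' or '/' as separator
ids = np.array([int(os.path.splitext(os.path.basename(f.replace('\\', '/')))[0]) for f in filenames])

# fake predictions in generator order, shape (n, 1) like model.predict output
toy_preds = (ids / 1000.0).reshape(-1, 1)

order = np.argsort(ids)
sorted_ids = ids[order]
sorted_preds = toy_preds.ravel()[order]
print(filenames)
print(sorted_ids, sorted_preds)
```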

In [57]:
# Sort file_list in ascending order and reorder predictions with the same permutation
file_list_index = np.argsort(file_list)

predictions = np.ravel(predictions)[file_list_index].astype(np.float32)
In [59]:
predictions[:10]
Out[59]:
array([9.9984729e-01, 9.9999475e-01, 9.9998832e-01, 9.9989271e-01,
       7.1073515e-07, 2.3044984e-06, 2.0153166e-05, 8.1204325e-06,
       6.5342442e-07, 6.4735364e-06], dtype=float32)
In [60]:
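# Kaggle scores submissions with log loss, which grows without bound as a wrong prediction approaches 0 or 1.
# Following how log_loss guards against infinity, clip the predictions to a bounded range; this noticeably improves the Kaggle score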

predictions = predictions.clip(min=0.005, max=0.995)
In [61]:
# Write the submission CSV: header, then one row per test id starting at 1
with open('submission_0.0069.csv', 'w') as f:
    f.write('id,label\n')
    for i in tqdm(range(len(predictions))):
        f.write('{},{}\n'.format(i + 1, predictions[i]))

print('file closed!')
100%|██████████| 12500/12500 [00:00<00:00, 367076.48it/s]
file closed!